import cv2
import numpy as np

from pathlib import Path

from img_deal import img_preprocess


def load_dataset(new_width: int, new_height: int,
                 file_path: str) -> tuple[np.ndarray, np.ndarray]:
    """
    Load every image in a directory along with the label encoded in its filename.

    Entries are visited in sorted filename order so the returned arrays are
    reproducible across runs and platforms (``Path.iterdir`` order is arbitrary).
    Sub-directories are ignored. Files that OpenCV cannot decode are skipped
    with a warning instead of being passed on to ``img_preprocess``.

    :param new_width: target image width forwarded to ``img_preprocess``
    :param new_height: target image height forwarded to ``img_preprocess``
    :param file_path: directory containing the dataset images
    :return: ``(data, label)``. ``data`` holds the preprocessed images stacked
        into a float array. ``label`` holds one single-element row per image,
        containing the second dot-separated field of the filename
        (e.g. ``"b"`` for ``a.b.jpg``).
    :raises FileNotFoundError: if ``file_path`` does not exist
    :raises IndexError: if an image filename contains no ``.`` separator
    """
    path = Path(file_path)
    data = []
    label = []
    # Sort so the sample order (and thus any train/test split) is deterministic.
    for entry in sorted(path.iterdir()):
        if not entry.is_file():
            continue
        img = cv2.imread(str(entry))
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            print(f"\n[warn]:无法读取图片，已跳过: {entry.name}")
            continue
        img = img_preprocess(img, new_width=new_width, new_height=new_height)
        label.append([entry.name.split('.')[1]])
        data.append(img)
        print(f"\r[info]:正在加载第{len(data)}张图片...", end='', flush=True)
    print()
    return np.array(data, dtype="float"), np.array(label)


if __name__ == '__main__':
    import sys

    # load_dataset requires the dataset directory; take it from the command line
    # (the previous call omitted it and always raised TypeError).
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <dataset_dir>")
    print(load_dataset(32, 32, sys.argv[1]))
